// Copyright 2024 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package botapi

import (
	"context"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/errors"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/gae/service/datastore"

	internalspb "go.chromium.org/luci/swarming/proto/internals"
	"go.chromium.org/luci/swarming/server/botinfo"
	"go.chromium.org/luci/swarming/server/botsession"
	"go.chromium.org/luci/swarming/server/botsrv"
	"go.chromium.org/luci/swarming/server/model"
	"go.chromium.org/luci/swarming/server/tasks"
)

// botLastSeenUpdateInterval is the minimum age of a bot's LastSeen timestamp
// before a (non-final) task update also refreshes the bot's BotInfo.
const botLastSeenUpdateInterval time.Duration = 20 * time.Second

// TaskUpdateRequest is sent by the bot.
//
// It is used both for incremental updates (e.g. output chunks) and for the
// final update that completes the task. Which one it is gets decided by the
// fields set: see isTaskCompletion in this file.
type TaskUpdateRequest struct {
	// Session is a serialized Swarming Bot Session proto.
	Session []byte `json:"session"`

	// TaskID is the TaskResultSummary packed key of the task to update.
	//
	// Required.
	TaskID string `json:"task_id"`

	// RequestUUID is used to skip reporting duplicate events on retries.
	//
	// Generated by the client (usually an UUID4 string). Optional.
	RequestUUID string `json:"request_uuid,omitempty"`

	// Performance stats
	// BotOverhead is the total overhead in seconds, summing overhead from all
	// sections defined below. If set, Duration must be set as well.
	BotOverhead *float64 `json:"bot_overhead,omitempty"`
	// CacheTrimStats is stats of cache trimming before the dependency installations.
	CacheTrimStats model.OperationStats `json:"cache_trim_stats,omitempty"`
	// CIPDStats is stats of installing CIPD packages before the task.
	CIPDStats model.OperationStats `json:"cipd_stats,omitempty"`
	// IsolatedStats is stats of CAS operations, including CAS dependencies
	// download before the task and CAS uploading operation after the task.
	IsolatedStats IsolatedStats `json:"isolated_stats,omitempty"`
	// NamedCachesStats is stats of named cache operations, including mounting
	// before the task and unmounting after the task.
	NamedCachesStats NamedCachesStats `json:"named_caches_stats,omitempty"`
	// CleanupStats is stats of work directory cleanup operation after the task.
	CleanupStats model.OperationStats `json:"cleanup_stats,omitempty"`

	// Task updates
	// CASOutputRoot is the digest of the output root uploaded to RBE-CAS.
	CASOutputRoot model.CASReference `json:"cas_output_root,omitempty"`
	// CIPDPins is resolved versions of all the CIPD packages used in the task.
	CIPDPins model.CIPDInput `json:"cipd_pins,omitempty"`
	// CostUSD is an approximate bot time cost spent executing this task.
	// Must not be negative.
	CostUSD float64 `json:"cost_usd,omitempty"`
	// Duration is the time spent in seconds for this task, excluding overheads.
	// Must not be negative, and must be set if and only if ExitCode is set.
	Duration *float64 `json:"duration,omitempty"`
	// ExitCode is the task process exit code for tasks in COMPLETED state.
	// Setting it marks this request as a task completion.
	ExitCode *int64 `json:"exit_code,omitempty"`
	// HardTimeout is a bool on whether a hard timeout occurred.
	HardTimeout bool `json:"hard_timeout,omitempty"`
	// IOTimeout is a bool on whether an I/O timeout occurred.
	IOTimeout bool `json:"io_timeout,omitempty"`
	// Output is the data to append to the stdout content for the task.
	Output []byte `json:"output,omitempty"`
	// OutputChunkStart is the index of output in the stdout stream.
	OutputChunkStart int64 `json:"output_chunk_start,omitempty"`
	// Canceled is a bool on whether the task has been canceled at set up stage.
	//
	// It is set only for the case where the server receives a request to cancel
	// the task after a bot claims it but before the bot actually starts to run
	// it.
	// * If the server receives the cancel request before the task is claimed,
	//   the task will be ended with "CANCELED" state right away and no bot can
	//   claim it.
	// * If the server receives the cancel request after the bot starts to run
	//   the task, the bot will end the task gracefully and send a task update
	//   with `duration` set when it's done. The task will then be ended with
	//   "KILLED" state.
	//
	// In the set-up-stage case described above, the task is also ended with
	// "CANCELED" state.
	Canceled bool `json:"canceled,omitempty"`
}

// IsolatedStats is stats of CAS operations.
//
// It is reported by the bot as part of TaskUpdateRequest and copied into the
// IsolatedDownload/IsolatedUpload fields of model.PerformanceStats on task
// completion.
type IsolatedStats struct {
	// Download is stats of CAS dependencies download before the task.
	Download model.CASOperationStats `json:"download,omitempty"`
	// Upload is stats of CAS uploading operation after the task.
	Upload model.CASOperationStats `json:"upload,omitempty"`
}

// NamedCachesStats is stats of named cache operations.
//
// It is reported by the bot as part of TaskUpdateRequest and copied into the
// NamedCachesInstall/NamedCachesUninstall fields of model.PerformanceStats on
// task completion.
type NamedCachesStats struct {
	// Install is stats of named cache mounting before the task.
	Install model.OperationStats `json:"install,omitempty"`
	// Uninstall is stats of named cache unmounting after the task.
	Uninstall model.OperationStats `json:"uninstall,omitempty"`
}

// ExtractSession returns the serialized bot session carried by the request.
func (r *TaskUpdateRequest) ExtractSession() []byte {
	session := r.Session
	return session
}
// ExtractDebugRequest returns a shallow copy of the request for debug
// purposes, with the Session and Output byte fields cleared. The receiver
// itself is left untouched.
func (r *TaskUpdateRequest) ExtractDebugRequest() any {
	dbg := *r
	dbg.Session, dbg.Output = nil, nil
	return &dbg
}

// TaskUpdateResponse is returned by the server.
//
// In this file it is returned with OK=false (and no error) when the bot is
// not associated with the task it reports on. Otherwise OK=true and Session
// holds the refreshed session.
type TaskUpdateResponse struct {
	// MustStop indicates whether the bot should stop the current task.
	MustStop bool `json:"must_stop"`

	// StopReason is the reason why the task must stop.
	// Should only be set when MustStop is true.
	StopReason string `json:"stop_reason,omitempty"`

	// OK indicates whether the update was processed successfully.
	OK bool `json:"ok"`

	// Session is a serialized bot session proto.
	//
	// If not empty, contains the refreshed session.
	Session []byte `json:"session,omitempty"`
}

// TaskUpdate implements handler that collects task state updates.
//
// Called by bots to report the tasks they are running to server. Requests
// that carry a final state (see isTaskCompletion) complete the task. All
// others are applied as incremental updates.
//
// On error, a literal nil Response is returned, not a nil *TaskUpdateResponse
// wrapped in the botsrv.Response interface (which would compare non-nil), so
// callers that inspect the response see the expected nil.
func (srv *BotAPIServer) TaskUpdate(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (botsrv.Response, error) {
	var (
		resp *TaskUpdateResponse
		err  error
	)
	if isTaskCompletion(body) {
		resp, err = srv.completeTask(ctx, body, r)
	} else {
		resp, err = srv.updateTask(ctx, body, r)
	}
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// updateTask handles a non-final task update (e.g. an output chunk).
//
// It applies the update via tasksManager.UpdateTxn inside a datastore
// transaction and then refreshes the bot session. If the bot's LastSeen is
// older than botLastSeenUpdateInterval, it also submits a BotEventTaskUpdate
// bot info update. If the bot is not associated with the task, it returns a
// response telling the bot to stop, and no error.
func (srv *BotAPIServer) updateTask(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (*TaskUpdateResponse, error) {
	op, err := srv.processTaskUpdate(ctx, body, r)
	switch {
	case errors.Is(err, wrongTaskIDErr):
		logging.Warningf(ctx, "The bot is not associated with this task on the server")
		return &TaskUpdateResponse{OK: false, MustStop: true, StopReason: err.Error()}, nil
	case err != nil:
		return nil, err
	}

	var txnOutcome *tasks.UpdateTxnOutcome
	txnErr := datastore.RunInTransaction(ctx, func(ctx context.Context) error {
		var updErr error
		txnOutcome, updErr = srv.tasksManager.UpdateTxn(ctx, op)
		return updErr
	}, nil)
	if txnErr != nil {
		return nil, txnErr
	}

	refreshed, err := srv.updateBotSession(ctx, r.Session)
	if err != nil {
		return nil, err
	}

	resp := &TaskUpdateResponse{
		OK:         true,
		MustStop:   txnOutcome.MustStop,
		StopReason: txnOutcome.StopReason,
		Session:    refreshed,
	}

	// LastSeen only needs refreshing every botLastSeenUpdateInterval. While it
	// is still fresh, skip the bot info update entirely.
	if age := clock.Now(ctx).UTC().Sub(r.BotLastSeen); age < botLastSeenUpdateInterval {
		return resp, nil
	}

	callInfo := botCallInfo(ctx, &botinfo.CallInfo{SessionID: r.Session.SessionId})
	infoUpdate := &botinfo.Update{
		BotID:         r.Session.BotId,
		EventDedupKey: body.RequestUUID,
		EventType:     model.BotEventTaskUpdate,
		TasksManager:  srv.tasksManager,
		CallInfo:      callInfo,
	}
	if err := srv.submitUpdate(ctx, infoUpdate); err != nil {
		if status.Code(err) == codes.Unknown {
			// Not a gRPC status error: log the details, return a generic Internal.
			logging.Errorf(ctx, "Failed to update bot info: %s", err)
			return nil, status.Errorf(codes.Internal, "failed to update bot info %q", r.Session.BotId)
		}
		return nil, err
	}
	return resp, nil
}

// processTaskUpdate validates body and builds the tasks.UpdateOp carrying the
// reported output chunk for the task.
//
// Returns the error from validateTaskUpdateRequest if validation fails.
func (srv *BotAPIServer) processTaskUpdate(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (*tasks.UpdateOp, error) {
	req, err := validateTaskUpdateRequest(ctx, body, r)
	if err != nil {
		return nil, err
	}
	op := &tasks.UpdateOp{
		Request:          req,
		BotID:            r.Session.BotId,
		Output:           body.Output,
		OutputChunkStart: body.OutputChunkStart,
	}
	return op, nil
}

// completeTask handles the final update of a task.
//
// The task is completed by tasksManager.CompleteTxn, which runs in the Prepare
// step of a bot info update. The bot event type recorded for that update is
// taken from the completion outcome. On success the bot session is refreshed.
// If the bot is not associated with the task, it returns OK=false and no error.
func (srv *BotAPIServer) completeTask(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (*TaskUpdateResponse, error) {
	op, err := srv.processTaskCompletion(ctx, body, r)
	switch {
	case errors.Is(err, wrongTaskIDErr):
		logging.Warningf(ctx, "The bot is not associated with this task on the server")
		return &TaskUpdateResponse{OK: false}, nil
	case err != nil:
		return nil, err
	}

	prepare := func(ctx context.Context, _ *model.BotInfo) (*botinfo.PrepareOutcome, error) {
		res, txnErr := srv.tasksManager.CompleteTxn(ctx, op)
		if txnErr != nil {
			return nil, txnErr
		}
		return &botinfo.PrepareOutcome{Proceed: true, EventType: res.BotEventType}, nil
	}

	infoUpdate := &botinfo.Update{
		BotID:         r.Session.BotId,
		EventDedupKey: body.RequestUUID,
		TasksManager:  srv.tasksManager,
		Prepare:       prepare,
		CallInfo:      botCallInfo(ctx, &botinfo.CallInfo{SessionID: r.Session.SessionId}),
	}
	if err := srv.submitUpdate(ctx, infoUpdate); err != nil {
		if status.Code(err) == codes.Unknown {
			return nil, status.Errorf(codes.Internal, "failed to update the task: %s", err)
		}
		return nil, err
	}

	refreshed, err := srv.updateBotSession(ctx, r.Session)
	if err != nil {
		return nil, err
	}
	return &TaskUpdateResponse{OK: true, Session: refreshed}, nil
}

// processTaskCompletion validates body and prepares the parameters to perform
// the task completion.
//
// On success, returns the prepared tasks.CompleteOp. Its PerformanceStats is
// populated only when the bot reported BotOverhead, and is nil otherwise.
// On failure, returns the error from validateTaskUpdateRequest. Callers check
// it for wrongTaskIDErr, which signals that the bot is not associated with
// the task.
func (srv *BotAPIServer) processTaskCompletion(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (*tasks.CompleteOp, error) {
	tr, err := validateTaskUpdateRequest(ctx, body, r)
	if err != nil {
		return nil, err
	}

	var perfStats *model.PerformanceStats
	if body.BotOverhead != nil {
		perfStats = &model.PerformanceStats{
			BotOverheadSecs:      *body.BotOverhead,
			CacheTrim:            body.CacheTrimStats,
			PackageInstallation:  body.CIPDStats,
			NamedCachesInstall:   body.NamedCachesStats.Install,
			NamedCachesUninstall: body.NamedCachesStats.Uninstall,
			Cleanup:              body.CleanupStats,
			IsolatedDownload:     body.IsolatedStats.Download,
			IsolatedUpload:       body.IsolatedStats.Upload,
		}
	}

	return &tasks.CompleteOp{
		Request:          tr,
		BotID:            r.Session.BotId,
		PerformanceStats: perfStats,
		Canceled:         body.Canceled,
		CASOutputRoot:    body.CASOutputRoot,
		CIPDPins:         body.CIPDPins,
		CostUSD:          body.CostUSD,
		Duration:         body.Duration,
		ExitCode:         body.ExitCode,
		HardTimeout:      body.HardTimeout,
		IOTimeout:        body.IOTimeout,
		Output:           body.Output,
		OutputChunkStart: body.OutputChunkStart,
	}, nil
}

// updateBotSession bumps the session expiry to botsession.Expiry from now,
// refreshes its DebugInfo and returns the session serialized via
// botsession.Marshal with srv.hmacSecret.
//
// The passed session proto is modified in place.
func (srv *BotAPIServer) updateBotSession(ctx context.Context, session *internalspb.Session) ([]byte, error) {
	session.DebugInfo = botsession.DebugInfo(ctx, srv.version)
	deadline := clock.Now(ctx).Add(botsession.Expiry)
	session.Expiry = timestamppb.New(deadline)

	blob, err := botsession.Marshal(session, srv.hmacSecret)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "fail to marshal session proto: %s", err)
	}
	return blob, nil
}

// isTaskCompletion reports whether body is the final update for a task.
// That is the case when the bot reported an exit code, a cancellation, or a
// hard or I/O timeout.
func isTaskCompletion(body *TaskUpdateRequest) bool {
	switch {
	case body.ExitCode != nil, body.Canceled, body.HardTimeout, body.IOTimeout:
		return true
	default:
		return false
	}
}

// validateTaskUpdateRequest checks that body refers to the bot's current task
// (via validateTaskID) and that its fields are mutually consistent.
//
// On success, returns the TaskRequest produced by validateTaskID. Otherwise it
// returns the error from validateTaskID, or an InvalidArgument status error
// when:
//   - CostUSD is negative;
//   - Duration is set and negative;
//   - exactly one of Duration and ExitCode is set;
//   - BotOverhead is set without Duration.
func validateTaskUpdateRequest(ctx context.Context, body *TaskUpdateRequest, r *botsrv.Request) (*model.TaskRequest, error) {
	tr, err := validateTaskID(ctx, body.TaskID, r.CurrentTaskID)
	if err != nil {
		return nil, err
	}

	if body.CostUSD < 0 {
		return nil, status.Errorf(codes.InvalidArgument, "negative cost %f", body.CostUSD)
	}
	if body.Duration != nil && *body.Duration < 0 {
		return nil, status.Errorf(codes.InvalidArgument, "negative duration %f", *body.Duration)
	}

	// Duration and ExitCode must be reported together. Format the dereferenced
	// value of whichever one is present: formatting the pointers themselves
	// with %v would print memory addresses instead of the reported values.
	switch {
	case body.Duration != nil && body.ExitCode == nil:
		return nil, status.Errorf(
			codes.InvalidArgument,
			"expected to have both duration and exit code or neither, got duration %v, exit code <nil>",
			*body.Duration)
	case body.Duration == nil && body.ExitCode != nil:
		return nil, status.Errorf(
			codes.InvalidArgument,
			"expected to have both duration and exit code or neither, got duration <nil>, exit code %v",
			*body.ExitCode)
	}

	if body.BotOverhead != nil && body.Duration == nil {
		return nil, status.Errorf(
			codes.InvalidArgument,
			"duration must be set when bot overhead is set")
	}
	return tr, nil
}
